[TIR] Add TIR While node #7425

Merged
merged 29 commits into from
Mar 3, 2021

masahi
Member

@masahi masahi commented Feb 9, 2021

This is an implementation of the TIR While node as discussed in RFC https://discuss.tvm.apache.org/t/rfc-add-while-loop-node-to-tir/9028. It supersedes my earlier attempt in #7385.

The PR consists of

  • IR node definition + boilerplate
  • Minimal changes to TIR transform passes (so far only modifies storage_rewrite.cc, everything else uses the default visitor)
  • LLVM and C source codegen
  • Various test cases
  • Update CUDA NMS to use while loop

Hybrid script support etc. is left for future work.

Now we can write binary search succinctly as follows:

lo[0] = 0
hi[0] = n
v = B[i]

with ib.while_loop(lo[0] < hi[0]):
    mid = lo[0] + (hi[0] - lo[0] >> 1)
    with ib.if_scope(A[mid] < v):
        lo[0] = mid + 1
    with ib.else_scope():
        hi[0] = mid

C[i] = lo[0]
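
For reference, the loop above computes the lower bound of v in a sorted A, that is, the first index whose element is not less than v. A minimal plain-Python sketch of the same logic, with no TVM dependency, can be used to sanity-check the expected result. The function name `lower_bound` is chosen here for illustration and is not part of the PR.

```python
import bisect


def lower_bound(A, v):
    """Return the first index i in the sorted list A such that A[i] >= v."""
    lo, hi = 0, len(A)
    while lo < hi:
        # Same midpoint computation as in the IR builder snippet.
        mid = lo + ((hi - lo) >> 1)
        if A[mid] < v:
            lo = mid + 1
        else:
            hi = mid
    return lo


A = [1, 3, 3, 5, 8]
for v in [0, 3, 4, 9]:
    # The standard library's bisect_left implements the same search.
    assert lower_bound(A, v) == bisect.bisect_left(A, v)
```

The number of iterations depends on the data, which is why this search is awkward to express with a fixed-extent For loop.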

As another nice use of while loop, I added a test that draws a useless mandelbrot set 🙂

[Image: Mandelbrot set rendered by the test]
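
The Mandelbrot test is a natural fit for a while loop because the escape-time iteration runs a data-dependent number of steps per pixel. A plain-Python sketch of that kind of loop follows. The function name, the iteration cap, and the escape radius are illustrative assumptions, not values taken from the PR's test.

```python
def escape_iterations(c, max_iter=50):
    """Count iterations of z = z*z + c until |z| > 2 or max_iter is reached."""
    z = 0j
    n = 0
    # The trip count depends on c, so it is not known before the loop starts.
    while n < max_iter and abs(z) <= 2.0:
        z = z * z + c
        n += 1
    return n
```

Points inside the set hit the iteration cap, while points far outside escape after a step or two.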

@tqchen @junrushao1994 @vinx13 @mbrookhart @zhiics @kevinthesun @anijain2305 @trevor-m

Contributor

@giuseros giuseros left a comment

LGTM, very nice addition!

@tqchen
Member

tqchen commented Feb 9, 2021

Thanks @masahi. Before we merge this, it would be really helpful to go through the current list of passes and check whether special handling of While is needed (so we don't bring in new bugs because of the mix). I would at least check the passes that need special IfThenElse handling.

For example, I can see the need to update the following pass:

  • Vectorize (we will need to abort if the condition is vectorized)

@tqchen
Member

tqchen commented Feb 9, 2021

Also cc @zxybazh, please help review this PR.

Member

@junrushao junrushao left a comment

Thanks for the PR! It looks good to me :-) Surprisingly it doesn't need to change any passes besides storage_rewrite :-)

@junrushao
Member

CC @spectrometerHBH: we might want to support this in TensorIR too, either as syntactic sugar for opaque binding or in some other way.

@@ -109,6 +110,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode);
IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode);
IR_STMT_FUNCTOR_DISPATCH(ForNode);
IR_STMT_FUNCTOR_DISPATCH(WhileNode);
Member

This needs a check through the current passes, per my comment.

Member

@zxybazh zxybazh left a comment

Thanks @masahi! Looks good to me.

@masahi
Member Author

masahi commented Feb 10, 2021

@tqchen @junrushao1994 @vinx13

I went through the passes and here is my summary:

  • VectorizeLoop: Need to disallow a while loop inside a vectorized loop. Without this check, no error occurs during lowering but the lowered code is incorrect. Added a test case test_vectorize_while_fail() to make sure we error out in such cases.

  • StorageAccessVisitor: I don't understand what it does, but added a special visitor for While following the existing visitor for IfThenElse. Please check 1e629b6

  • CoProcSync and LiftAttrScope: They both have a special visitor for IfThenElse, but I don't understand them. They are only used by VTA, so for now I just error out if we find a WhileNode there. See a71066d and 00c17d9

  • InjectVirtualThread: I think we need some special handling for this, but I don't know what it should be. For now I just added a placeholder and call the base class visitor. See 896b02f and let me know what we should do here.

  • Do we need to change MergeNest? I haven't touched it for now

    Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
      // use reverse iteration
      for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
        Stmt s = *ri;
        if (const auto* for_ = s.as<ForNode>()) {
          auto n = make_object<ForNode>(*for_);
          ICHECK(is_no_op(n->body));
          n->body = body;
          body = Stmt(n);
        } else if (const auto* let = s.as<LetStmtNode>()) {
          auto n = make_object<LetStmtNode>(*let);
          ICHECK(is_no_op(n->body));
          n->body = body;
          body = Stmt(n);
        } else if (const auto* attr = s.as<AttrStmtNode>()) {
          auto n = make_object<AttrStmtNode>(*attr);
          ICHECK(is_no_op(n->body));
          n->body = body;
          body = Stmt(n);
        } else if (const auto* ite = s.as<IfThenElseNode>()) {
          auto n = make_object<IfThenElseNode>(*ite);
          ICHECK(is_no_op(n->then_case));
          ICHECK(!n->else_case.defined());
          n->then_case = body;
          body = Stmt(n);

  • Probably we don't need to change hoist_if_then_else.cc and loop_partition.cc. We can do something in remove_no_op.cc, but I think it is not important.
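
The VectorizeLoop rule from the list above (reject a while loop anywhere under a vectorized loop) can be sketched on a toy IR. The classes below are hypothetical stand-ins written for this illustration; they are not TVM's node types, and the real check lives in the C++ pass.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Stmt:
    """Base class for the toy statements (hypothetical, not TVM classes)."""
    body: List["Stmt"] = field(default_factory=list)


@dataclass
class For(Stmt):
    kind: str = "serial"  # "serial", "parallel" or "vectorized"


@dataclass
class While(Stmt):
    pass


@dataclass
class Store(Stmt):
    pass


def check_vectorizable(stmt, inside_vectorized=False):
    """Raise ValueError if a While appears anywhere under a vectorized For."""
    if isinstance(stmt, While) and inside_vectorized:
        raise ValueError("while loop inside a vectorized loop is not supported")
    if isinstance(stmt, For) and stmt.kind == "vectorized":
        inside_vectorized = True
    for child in stmt.body:
        check_vectorizable(child, inside_vectorized)


# A while loop under a serial loop is accepted.
check_vectorizable(For(kind="serial", body=[While(body=[Store()])]))

# The same while loop under a vectorized loop is rejected.
try:
    check_vectorizable(For(kind="vectorized", body=[While(body=[Store()])]))
except ValueError as err:
    print("rejected:", err)
```

Failing loudly here matches the motivation given in the list: silently lowering such a loop would produce incorrect code.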

@masahi
Member Author

masahi commented Feb 12, 2021

@tqchen Can you have a look?

@tqchen
Member

tqchen commented Feb 12, 2021

I left a comment on InjectVirtualThread. @junrushao1994 @ZihengJiang @vinx13, it would be great if you could also help check the StorageAccessVisitor.

@vinx13
Member

vinx13 commented Feb 12, 2021

I've checked StorageAccessVisitor and it looks good to me. InplaceOpVerifier and StoragePlanRewriter also need handling.

@masahi
Member Author

masahi commented Feb 12, 2021

@vinx13 OK, for InplaceOpVerifier I think I need to update

if (stmt->IsInstance<AttrStmtNode>()) {
  VisitStmt_(static_cast<const AttrStmtNode*>(stmt));
} else if (stmt->IsInstance<ForNode>()) {
  VisitStmt_(static_cast<const ForNode*>(stmt));
} else if (stmt->IsInstance<IfThenElseNode>()) {
  VisitStmt_(static_cast<const IfThenElseNode*>(stmt));
} else if (stmt->IsInstance<StoreNode>()) {
  VisitStmt_(static_cast<const StoreNode*>(stmt));
} else {
  return false;
}

But I don't see how we should update StoragePlanRewriter. Maybe here?

// enter/exit new scope
if (s.stmt->IsInstance<AttrStmtNode>()) {
  const auto* op = static_cast<const AttrStmtNode*>(s.stmt);
  if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread ||
      attr::IsPragmaKey(op->attr_key)) {
    PlanNewScope(op);
  } else {
    ICHECK(op->attr_key == attr::extern_scope);
  }
} else if (s.stmt->IsInstance<ForNode>()) {
  const auto* op = static_cast<const ForNode*>(s.stmt);
  if (op->kind == ForKind::kParallel) {
    if (thread_scope_ == nullptr || thread_scope_ == op) {
      PlanNewScope(op);
    }
  }
}

@tqchen
Member

tqchen commented Feb 12, 2021

Thanks @masahi , it would also be great for you to spend a bit more time to look into these passes :) It certainly takes more time, but we will also have more experts in TIR passes :)

Please also consider adding test cases for the passes that need While handling.

@vinx13
Member

vinx13 commented Feb 12, 2021

@masahi For StoragePlanRewriter, we need to do something similar to ForNode

Stmt VisitStmt_(const ForNode* op) final {
  ICHECK(op->kind != ForKind::kVectorized) << "VectorizeLoop before LiftStorageAlloc";
  // remake all the allocation at the attach scope.
  if (attach_map_.count(op)) {
    auto& svec = attach_map_[op];
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<ForNode>();
    return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body),
               op->thread_binding, op->annotations);
  } else {
    return StmtExprMutator::VisitStmt_(op);
  }
}

@masahi
Member Author

masahi commented Feb 12, 2021

ok, to me it's not obvious what it is doing, time for another deep dive...

@masahi masahi marked this pull request as draft February 15, 2021 22:19
@masahi masahi marked this pull request as ready for review February 15, 2021 22:57
@masahi
Member Author

masahi commented Feb 16, 2021

@tqchen @vinx13 @junrushao1994 Does the behavior of While node wrt StorageRewrite below look reasonable?

In the following IR, the "A" and "B" buffers, which are allocated in For loops, are coalesced into one buffer, but the "C" buffer, which is allocated inside the While loop, is not:

def test_parallel_alloc():
    ib = tvm.tir.ir_builder.create()
    n = te.var("n")
    with ib.for_range(0, n, name="i", kind="parallel") as i:
        with ib.for_range(0, 10, name="j") as j:
            A = ib.allocate("float32", n, name="A", scope="global")
            A[j] = A[j] + 2

        with ib.for_range(0, 10, name="j") as j:
            B = ib.allocate("float32", n, name="B", scope="global")
            B[j] = B[j] + 2

        i = ib.allocate("int32", (1,), name="i", scope="local")
        i[0] = 1
        with ib.while_loop(i[0] < 10):
            C = ib.allocate("float32", n, name="C", scope="local")
            C[i[0]] = C[i[0]] + 2
            i[0] += 1

Resulting IR:

parallel (i, 0, n) {
  // attr [A] storage_scope = "global"
  allocate A[float32 * n]
  // attr [i] storage_scope = "local"
  allocate i[int32 * 1]
  // attr [C] storage_scope = "local"
  allocate C[float32 * n]
  for (j, 0, 10) {
    A[j] = (A[j] + 2f)
  }
  for (j, 0, 10) {
    A[j] = (A[j] + 2f)
  }
  i[0] = 1
  while((i[0] < 10)){
    C[i[0]] = (C[i[0]] + 2f)
    i[0] = (i[0] + 1)
  }
}

In the following IR, all buffers, including the one allocated inside While loop, are coalesced:

def test_alloc_seq():
    scope_tb = "local.L0A"
    max_bits = 1024 * 1024 * 1024

    register_mem(scope_tb, max_bits)

    ib = tvm.tir.ir_builder.create()
    n = te.var("n")
    with ib.for_range(0, n, name="i") as i:
        with ib.for_range(0, 10, name="j") as j:
            A = ib.allocate("float32", 200, name="A", scope=scope_tb)
            A[j] = 1.2
        with ib.for_range(0, 10, name="j") as j:
            B = ib.allocate("float32", 200, name="B", scope=scope_tb)
            B[j] = 1.3

        i = ib.allocate("int32", (1,), name="i", scope="local")
        i[0] = 1
        with ib.while_loop(i[0] < 10):
            C = ib.allocate("float32", 200, name="C", scope=scope_tb)
            C[i[0]] = 1.4
            i[0] += 1

    body = ib.get()

Resulting IR:

// attr [A] storage_scope = "local.L0A"
allocate A[float32 * 200]
// attr [i] storage_scope = "local"
allocate i[int32 * 1]
for (i, 0, n) {
  for (j, 0, 10) {
    A[j] = 1.2f
  }
  for (j, 0, 10) {
    A[j] = 1.3f
  }
  i[0] = 1
  while((i[0] < 10)){
    A[i[0]] = 1.4f
    i[0] = (i[0] + 1)
  }
}


@tqchen
Member

tqchen commented Feb 22, 2021

@vinx13 can you please take another look at the PR and manage?

@tqchen tqchen added the "status: need review" and "status: need test case" labels Feb 22, 2021
Member

@tqchen tqchen left a comment

Thanks @masahi! The change has addressed my previous comments. Please add test cases for the transforms that require special While handling, so that these passes are covered.

Contributor

@ZihengJiang ZihengJiang left a comment

Nice work! Thanks @masahi

@masahi
Member Author

masahi commented Mar 2, 2021

@tqchen @junrushao1994 @vinx13 @ZihengJiang @zxybazh

I came to the conclusion that the While node doesn't need special handling in storage_rewrite.

The first observation is that even if I remove all ForNode handling from StoragePlanRewriter, all tests in test_tir_transform_storage_rewrite.py except test_parallel_alloc() pass.

If we look at the visitor for ForNode,

Stmt VisitStmt_(const ForNode* op) final {
  ICHECK(op->kind != ForKind::kVectorized) << "VectorizeLoop before LiftStorageAlloc";
  // remake all the allocation at the attach scope.
  if (attach_map_.count(op)) {
    auto& svec = attach_map_[op];
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<ForNode>();
    return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body),
               op->thread_binding, op->annotations);
  } else {
    return StmtExprMutator::VisitStmt_(op);
  }
}

it only does something special when attach_map_ has an entry for this node. Here comes the second observation: the only case where attach_map_ can have an entry for a ForNode is if this ForNode is a parallel for loop, due to these lines:

} else if (s.stmt->IsInstance<ForNode>()) {
  const auto* op = static_cast<const ForNode*>(s.stmt);
  if (op->kind == ForKind::kParallel) {
    if (thread_scope_ == nullptr || thread_scope_ == op) {
      PlanNewScope(op);
    }
  }

Together, these two handlers for ForNode lift allocations out of an inner loop and attach the merged allocation under the parallel loop scope, via the MakeAttach function in

return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body),

This is what's tested in test_parallel_alloc(). For other kinds of For loop, a merged allocation is placed at the global scope, see

struct StorageEntry {
  // The scope that this alloc attaches after
  // For shared/local memory it is beginning of the thread extent.
  // for global memory it is nullptr, means beginning of everything.
  const Object* attach_scope_{nullptr};

Since the While node doesn't involve threading, I think we can always lift allocations done inside a While loop into the global scope. That means WhileNode should be handled in the same way as non-parallel ForNodes, i.e. we don't need special handling logic for WhileNode. Two simple test cases involving a While loop are added to test that allocation is attached at the right scope after storage_rewrite.
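
The reasoning above can be condensed into a toy model: only a parallel For opens an attach scope, while a serial For and a While leave the current scope unchanged. The sketch below is a deliberate simplification written for illustration. It uses nested tuples instead of TIR nodes, the name `attach_scopes` is hypothetical, and it ignores thread_extent attributes, memory scopes, and buffer merging.

```python
def attach_scopes(stmt, scope=None, out=None):
    """Map each allocated buffer name to the loop it attaches under.

    stmt is a nested tuple: ("for", kind, name, [children]),
    ("while", [children]) or ("alloc", buffer_name).
    A scope of None stands for the global (outermost) scope.
    """
    if out is None:
        out = {}
    tag = stmt[0]
    if tag == "alloc":
        out[stmt[1]] = scope
    elif tag == "for":
        _, kind, name, children = stmt
        # Only the outermost parallel loop opens a new attach scope.
        if kind == "parallel" and scope is None:
            scope = name
        for child in children:
            attach_scopes(child, scope, out)
    elif tag == "while":
        # A While loop never opens a scope: same treatment as a serial For.
        for child in stmt[1]:
            attach_scopes(child, scope, out)
    return out


parallel_case = ("for", "parallel", "i", [
    ("for", "serial", "j", [("alloc", "A")]),
    ("while", [("alloc", "C")]),
])
serial_case = ("for", "serial", "i", [
    ("for", "serial", "j", [("alloc", "A")]),
    ("while", [("alloc", "C")]),
])
print(attach_scopes(parallel_case))  # {'A': 'i', 'C': 'i'}
print(attach_scopes(serial_case))    # {'A': None, 'C': None}
```

In both cases the allocation inside the While loop ends up in the same scope as the allocation inside the serial For loop, which is the behavior the two new tests are meant to pin down.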

I think I nailed it, thoughts?

@vinx13
Member

vinx13 commented Mar 2, 2021

@masahi You are right, thanks for looking into this

@junrushao
Member

That makes sense to me. Thanks for diving deep into this issue!

@masahi
Member Author

masahi commented Mar 2, 2021

cc @tqchen please take a look

@tqchen
Member

tqchen commented Mar 2, 2021

@masahi you are right that MakeAttach is only needed for the parallel for loop, where we can no longer lift the memory to the outside (otherwise the memory won't be thread local).

@tqchen
Member

tqchen commented Mar 2, 2021

@junrushao1994 @vinx13 please help to manage the PR

Member

@tqchen tqchen left a comment

One minor comment

tests/python/unittest/test_tir_ir_builder.py (outdated, resolved)
@masahi
Member Author

masahi commented Mar 3, 2021

@junrushao1994 @vinx13 @tqchen ready to merge...!!

@vinx13 vinx13 merged commit cf36aa6 into apache:main Mar 3, 2021
@vinx13
Member

vinx13 commented Mar 3, 2021

@vinx13 vinx13 added the "status: accepted" label and removed the "status: need review", "status: need test case", and "status: need update" labels Mar 3, 2021
@junrushao
Member

Really awesome work!!!

@masahi
Member Author

masahi commented Mar 3, 2021

Thank you very much for the reviews!!

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
* add while node

* update visitors

* binary search lowering works

* llvm codegen working

* cuda codegen working

* nms updated to use while loop

* add missing upper bound check too

* add mandelbrot test

* add gpu mandel

commit ee2363b
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Jan 29 11:44:02 2021 +0900

    enable extern lib offload for nvptx

* rename test

* run black

* add doc

* add collatz test

* add while + vectorize test

* simplify bin search

* Add special case visit method to storage_access.cc

* disallow while loop inside vectorized loop

* disallow trivial condition since we do not have break

* error out in CoprocSync for now

* error out LiftAttrScope for now

* add placeholder to inject_vpthread

* refactor to use MakeAttach

* handle WhileNode in InplaceOpVerifier

* error out in InjectVirtualThread

* try handle WhileNode in StoragePlanRewriter

* remove WhileNode visitor from storage rewrite

* add while loop storage rewrite test

* update tests

* move test_vectorize_while_fail to  test_tir_transform_vectorize.py
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request May 11, 2021
8 participants